import torch
import numpy as np

from dada.optimizer import WDA, UGM, DADA, DoG, Prodigy

from dada.model import ModelRunner
from dada.model.log_sum_exp import LogSumExpModel
from dada.utils import Param, run_different_opts, plot_optimizers_result


class LogSumExpRunner(ModelRunner):
    """Benchmark several optimizers on the LogSumExp model over a sweep of ``mu``.

    For every value in ``mu_list`` a problem instance is generated with
    ``LogSumExpModel.generate_function_variables`` and the WDA, UGM, DoG,
    Prodigy and DADA optimizers are each run on their own model instance,
    all starting from the same all-ones point.
    """

    # Name of the mu parameter inside ``Param`` objects. Raw string so that
    # "\m" is not parsed as an (invalid) escape sequence; the value is the
    # same two characters "\\" + "mu" the original literal produced.
    _MU_NAME = r"\mu"

    def __init__(self, params: dict):
        """Read the experiment configuration.

        Args:
            params: configuration dict. ``num_polyhedron`` and ``mu_list`` are
                read here; if ``mu_list`` is None, the single value ``mu`` is
                used instead. The dict is then passed on to
                ``ModelRunner.__init__`` (which presumably sets
                ``self.vector_size`` used by ``run`` — confirm).

        Raises:
            ValueError: if ``num_polyhedron`` is None, or if ``mu_list`` and
                ``mu`` are both None.
            KeyError: if a key read above is missing from ``params``.
        """
        self.num_polyhedron = params['num_polyhedron']
        if self.num_polyhedron is None:
            raise ValueError('num_polyhedron is None')

        self.mu_list = params['mu_list']
        if self.mu_list is None:
            mu = params['mu']
            if mu is None:
                raise ValueError('mu is None')
            self.mu_list = [mu]

        super().__init__(params)

    @staticmethod
    def _build_optimizers(d, n, a_matrix, b_matrix, mu, init, d0):
        """Create one model per method on the shared problem data and pair it with its optimizer.

        Returns a list of ``[optimizer, model]`` pairs in the order
        WDA, UGM, DoG, Prodigy, DADA.
        """
        factories = [
            lambda p: WDA(p, d0=d0),  # Dual Averaging Method (given the true initial distance)
            UGM,                      # GD with line search
            DoG,                      # DoG
            Prodigy,                  # Prodigy
            DADA,                     # DADA
        ]
        pairs = []
        for make_optimizer in factories:
            # Each model gets its own copy of the start point. If a model
            # keeps the tensor it is handed, one optimizer's in-place updates
            # then cannot alter another optimizer's starting point.
            start = init.detach().clone().requires_grad_(True)
            model = LogSumExpModel(d, n, a_matrix, b_matrix, mu, init_point=start)
            pairs.append([make_optimizer(model.params()), model])
        return pairs

    def run(self, iterations, model_name, save_plot, plots_directory):
        """Run every optimizer for each configured ``mu`` and plot the results.

        Args:
            iterations: number of optimizer steps per method.
            model_name: forwarded to ``plot_optimizers_result``.
            save_plot: forwarded as ``save`` to ``plot_optimizers_result``.
            plots_directory: forwarded to ``plot_optimizers_result``.
        """
        params = [
            Param(names=["n", "d", self._MU_NAME],
                  values=[self.num_polyhedron, self.vector_size, mu])
            for mu in self.mu_list
        ]
        value_distances_per_param = {}
        d_estimation_error_per_param = {}

        # The origin is used as the reference optimum (assumes the generated
        # LogSumExp instance is minimized at 0 — confirm against
        # generate_function_variables).
        optimal_point = np.zeros((self.vector_size,))

        # Logging / plot-marker period, about 10 points per run. Clamped to
        # >= 1 so runs with fewer than 10 iterations never pass a period of 0.
        report_every = max(1, iterations // 10)

        optimizers = []
        for param in params:
            print(param)
            n = param.get_param("n")
            d = param.get_param("d")
            mu = param.get_param(self._MU_NAME)

            init = torch.ones(self.vector_size, requires_grad=True, dtype=torch.double)
            a_matrix, b_matrix = LogSumExpModel.generate_function_variables(d, n, mu)
            d0 = np.linalg.norm(optimal_point - init.detach().numpy())

            optimizers = self._build_optimizers(d, n, a_matrix, b_matrix, mu, init, d0)

            d_estimation_error, value_distances = run_different_opts(
                optimizers, iterations, optimal_point, log_per=report_every)
            value_distances_per_param[param] = value_distances
            d_estimation_error_per_param[param] = d_estimation_error

        # Only the pairs from the last parameter setting reach the plotter
        # (presumably it uses them only for method labels — confirm).
        plot_optimizers_result(optimizers, params, value_distances_per_param, d_estimation_error_per_param,
                               model_name=model_name, save=save_plot, plots_directory=plots_directory,
                               mark_every=report_every)
